import os

from Global_variable_setting import normalization

from Building_dataset import get_dataset
from Model_related_processing import model_train
from Displaying_results import display_results


def _next_experiment_index(results_dir):
    """Return the run index to use for the next experiment under *results_dir*.

    Each run is stored in a sub-directory named after its index (``0``, ``1``,
    ...). The next index is one past the largest existing numeric entry. This
    means a previously deleted run, or a stray non-numeric entry, cannot make a
    new run reuse (and overwrite) an existing run directory. For a contiguous
    ``0..k-1`` layout the result is ``k``, the same as counting the entries.

    The directory is created first if missing. Otherwise ``os.listdir`` would
    raise ``FileNotFoundError`` on the very first run.
    """
    os.makedirs(results_dir, exist_ok=True)
    # isdecimal() (not isdigit()) so every accepted name is parseable by int().
    run_indices = [int(name) for name in os.listdir(results_dir) if name.isdecimal()]
    return max(run_indices, default=-1) + 1


# Root directory under which all training runs are stored (relative to the CWD).
train_result_dir = os.path.join(os.getcwd(), 'result', 'training_result')

from Global_variable_setting import dataset_kind
if dataset_kind == 'perfect':
    # Runs on the 'perfect' dataset are kept in their own sub-tree.
    train_result_dir = os.path.join(train_result_dir, 'good_data_for_test')

# Build the selected model; model_type_index 0 selects the Vision Transformer,
# any other value selects the CNN.
from Global_variable_setting import model_type_index
if model_type_index == 0:
    from Building_vision_transformer import get_vision_transformer
    model = get_vision_transformer()
    model_dir_name = 'Vision Transformer'
else:
    from Building_CNN import get_cnn
    model = get_cnn()
    model_dir_name = 'CNN'

# Results are grouped by model type and by whether inputs are normalized.
normalization_dir_name = 'normalized' if normalization else 'unnormalized'
training_result_dir_specific = os.path.join(
    train_result_dir, model_dir_name, normalization_dir_name
)
experiment_time = _next_experiment_index(training_result_dir_specific)

# Per-run output directory, e.g. .../CNN/normalized/3
training_result_dir_result = os.path.join(training_result_dir_specific, str(experiment_time))

ds_train, ds_test = get_dataset(normalization)
history = model_train(model, ds_train, ds_test, model_type_index, training_result_dir_result)
display_results(history, training_result_dir_result)
